{
 "cells": [
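  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ONNX Overview\n",
    "\n",
    "Sources: [ONNX](https://github.com/onnx/onnx) & [ONNX with Python](https://onnx.ai/onnx/intro/python.html)\n",
    "\n",
    "```{note}\n",
    "ONNX is an open format built to represent machine learning models. It defines a common set of operators (the building blocks of machine learning and deep learning models) and a common file format, so that AI developers can use models with a variety of frameworks, tools, runtimes, and compilers.\n",
    "```\n",
    "\n",
    "```{topic} Technical design\n",
    "ONNX provides a definition of an extensible computation graph model, together with definitions of built-in operators and standard data types.\n",
    "\n",
    "Each computation dataflow graph is structured as a list of nodes that form an acyclic graph. Nodes have one or more inputs and one or more outputs, and each node is a call to an operator. The graph also carries metadata that documents its purpose, author, and so on.\n",
    "\n",
    "Operators are implemented outside the graph, but the set of built-in operators is portable across frameworks. Every framework that supports ONNX provides implementations of these operators for the applicable data types.\n",
    "```"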
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ONNX Python API\n",
    "\n",
    "Reference: [onnx python](https://onnx.ai/onnx/intro/python.html)"
   ]
  },
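  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ONNX is strongly typed: shapes and types must be defined for both the inputs and the outputs of a function. The most commonly used [make functions](https://onnx.ai/onnx/api/helper.html#l-onnx-make-function) are:\n",
    "\n",
    "- `make_tensor_value_info`: declares a variable (input or output) with a given shape and type.\n",
    "- `make_node`: creates a node defined by an operation (an operator type), its inputs, and its outputs.\n",
    "- `make_graph`: creates an ONNX graph from the objects built with the two previous functions.\n",
    "- `make_model`: combines a graph with additional metadata.\n",
    "\n",
    "Throughout the construction, every input and output of every node in the graph must be given a name. The inputs and outputs of the graph are defined by ONNX objects (`ValueInfoProto`), whereas intermediate results are referred to by plain strings."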
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnx import TensorProto\n",
    "from onnx.helper import (\n",
    "    make_model, make_node, make_graph,\n",
    "    make_tensor_value_info\n",
    ")\n",
    "# from onnx.checker import check_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the input variables:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mSignature:\u001b[0m\n",
      "\u001b[0mmake_tensor_value_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0melem_type\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mshape\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mSequence\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdoc_string\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m''\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mshape_denotation\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0monnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx_ml_pb2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mValueInfoProto\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mDocstring:\u001b[0m Makes a ValueInfoProto based on the data type and shape.\n",
      "\u001b[0;31mFile:\u001b[0m      /media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/onnx/helper.py\n",
      "\u001b[0;31mType:\u001b[0m      function"
     ]
    }
   ],
   "source": [
    "make_tensor_value_info?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = make_tensor_value_info('X', TensorProto.FLOAT, [None, None])\n",
    "A = make_tensor_value_info('A', TensorProto.FLOAT, [None, None])\n",
    "B = make_tensor_value_info('B', TensorProto.FLOAT, [None, None])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inspect `X`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "name: \"X\"\n",
       "type {\n",
       "  tensor_type {\n",
       "    elem_type: 1\n",
       "    shape {\n",
       "      dim {\n",
       "      }\n",
       "      dim {\n",
       "      }\n",
       "    }\n",
       "  }\n",
       "}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the output variable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the nodes of the computation graph, which refer to the variables above by name:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "node1 = make_node('MatMul', ['X', 'A'], ['XA'])\n",
    "node2 = make_node('Add', ['XA', 'B'], ['Y'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inspect `node1`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "input: \"X\"\n",
       "input: \"A\"\n",
       "output: \"XA\"\n",
       "op_type: \"MatMul\""
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Assemble the nodes into a computation graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mSignature:\u001b[0m\n",
      "\u001b[0mmake_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mnodes\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mSequence\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0monnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx_ml_pb2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNodeProto\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0minputs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mSequence\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0monnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx_ml_pb2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mValueInfoProto\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0moutputs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mSequence\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0monnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx_ml_pb2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mValueInfoProto\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0minitializer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mSequence\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0monnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx_ml_pb2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensorProto\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdoc_string\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mvalue_info\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mSequence\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0monnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx_ml_pb2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mValueInfoProto\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0msparse_initializer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mSequence\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0monnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx_ml_pb2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSparseTensorProto\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0monnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx_ml_pb2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGraphProto\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mDocstring:\u001b[0m\n",
      "Construct a GraphProto\n",
      "\n",
      "Args:\n",
      "    nodes: list of NodeProto\n",
      "    name (string): graph name\n",
      "    inputs: list of ValueInfoProto\n",
      "    outputs: list of ValueInfoProto\n",
      "    initializer: list of TensorProto\n",
      "    doc_string (string): graph documentation\n",
      "    value_info: list of ValueInfoProto\n",
      "    sparse_initializer: list of SparseTensorProto\n",
      "Returns:\n",
      "    GraphProto\n",
      "\u001b[0;31mFile:\u001b[0m      /media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/onnx/helper.py\n",
      "\u001b[0;31mType:\u001b[0m      function"
     ]
    }
   ],
   "source": [
    "make_graph?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = make_graph(\n",
    "    [node1, node2],  # nodes\n",
    "    'lr',  # a name\n",
    "    [X, A, B],  # inputs\n",
    "    [Y] # outputs\n",
    ")  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "node {\n",
       "  input: \"X\"\n",
       "  input: \"A\"\n",
       "  output: \"XA\"\n",
       "  op_type: \"MatMul\"\n",
       "}\n",
       "node {\n",
       "  input: \"XA\"\n",
       "  input: \"B\"\n",
       "  output: \"Y\"\n",
       "  op_type: \"Add\"\n",
       "}\n",
       "name: \"lr\"\n",
       "input {\n",
       "  name: \"X\"\n",
       "  type {\n",
       "    tensor_type {\n",
       "      elem_type: 1\n",
       "      shape {\n",
       "        dim {\n",
       "        }\n",
       "        dim {\n",
       "        }\n",
       "      }\n",
       "    }\n",
       "  }\n",
       "}\n",
       "input {\n",
       "  name: \"A\"\n",
       "  type {\n",
       "    tensor_type {\n",
       "      elem_type: 1\n",
       "      shape {\n",
       "        dim {\n",
       "        }\n",
       "        dim {\n",
       "        }\n",
       "      }\n",
       "    }\n",
       "  }\n",
       "}\n",
       "input {\n",
       "  name: \"B\"\n",
       "  type {\n",
       "    tensor_type {\n",
       "      elem_type: 1\n",
       "      shape {\n",
       "        dim {\n",
       "        }\n",
       "        dim {\n",
       "        }\n",
       "      }\n",
       "    }\n",
       "  }\n",
       "}\n",
       "output {\n",
       "  name: \"Y\"\n",
       "  type {\n",
       "    tensor_type {\n",
       "      elem_type: 1\n",
       "      shape {\n",
       "        dim {\n",
       "        }\n",
       "      }\n",
       "    }\n",
       "  }\n",
       "}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Wrap the computation graph in a model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_model = make_model(graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ir_version: 10\n",
       "graph {\n",
       "  node {\n",
       "    input: \"X\"\n",
       "    input: \"A\"\n",
       "    output: \"XA\"\n",
       "    op_type: \"MatMul\"\n",
       "  }\n",
       "  node {\n",
       "    input: \"XA\"\n",
       "    input: \"B\"\n",
       "    output: \"Y\"\n",
       "    op_type: \"Add\"\n",
       "  }\n",
       "  name: \"lr\"\n",
       "  input {\n",
       "    name: \"X\"\n",
       "    type {\n",
       "      tensor_type {\n",
       "        elem_type: 1\n",
       "        shape {\n",
       "          dim {\n",
       "          }\n",
       "          dim {\n",
       "          }\n",
       "        }\n",
       "      }\n",
       "    }\n",
       "  }\n",
       "  input {\n",
       "    name: \"A\"\n",
       "    type {\n",
       "      tensor_type {\n",
       "        elem_type: 1\n",
       "        shape {\n",
       "          dim {\n",
       "          }\n",
       "          dim {\n",
       "          }\n",
       "        }\n",
       "      }\n",
       "    }\n",
       "  }\n",
       "  input {\n",
       "    name: \"B\"\n",
       "    type {\n",
       "      tensor_type {\n",
       "        elem_type: 1\n",
       "        shape {\n",
       "          dim {\n",
       "          }\n",
       "          dim {\n",
       "          }\n",
       "        }\n",
       "      }\n",
       "    }\n",
       "  }\n",
       "  output {\n",
       "    name: \"Y\"\n",
       "    type {\n",
       "      tensor_type {\n",
       "        elem_type: 1\n",
       "        shape {\n",
       "          dim {\n",
       "          }\n",
       "        }\n",
       "      }\n",
       "    }\n",
       "  }\n",
       "}\n",
       "opset_import {\n",
       "  version: 21\n",
       "}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "onnx_model"
   ]
  },
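  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note}\n",
    "An empty shape ({data}`None`) means any shape. A shape declared as `[None, None]` describes a tensor with two dimensions whose sizes are left unspecified.\n",
    "```\n",
    "\n",
    "## Inspecting the ONNX graph\n",
    "\n",
    "The ONNX graph can also be inspected by reading the fields of each object it contains.\n",
    "\n",
    "List the inputs:"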
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[name: \"X\"\n",
      "type {\n",
      "  tensor_type {\n",
      "    elem_type: 1\n",
      "    shape {\n",
      "      dim {\n",
      "      }\n",
      "      dim {\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      ", name: \"A\"\n",
      "type {\n",
      "  tensor_type {\n",
      "    elem_type: 1\n",
      "    shape {\n",
      "      dim {\n",
      "      }\n",
      "      dim {\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      ", name: \"B\"\n",
      "type {\n",
      "  tensor_type {\n",
      "    elem_type: 1\n",
      "    shape {\n",
      "      dim {\n",
      "      }\n",
      "      dim {\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "]\n"
     ]
    }
   ],
   "source": [
    "print(onnx_model.graph.input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Print the input information in a more readable form (unset dimensions are reported as `0`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shape2tuple(shape):\n",
    "    return tuple(getattr(d, 'dim_value', 0) for d in shape.dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "name='X' dtype=1 shape=(0, 0)\n",
      "name='A' dtype=1 shape=(0, 0)\n",
      "name='B' dtype=1 shape=(0, 0)\n"
     ]
    }
   ],
   "source": [
    "for obj in onnx_model.graph.input:\n",
    "    print(\"name=%r dtype=%r shape=%r\" % (\n",
    "        obj.name, obj.type.tensor_type.elem_type,\n",
    "        shape2tuple(obj.type.tensor_type.shape)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The outputs can be inspected in the same way:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "name='Y' dtype=1 shape=(0,)\n"
     ]
    }
   ],
   "source": [
    "for obj in onnx_model.graph.output:\n",
    "    print(\"name=%r dtype=%r shape=%r\" % (\n",
    "        obj.name, obj.type.tensor_type.elem_type,\n",
    "        shape2tuple(obj.type.tensor_type.shape)))"
   ]
  },
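  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The zeros printed above stand for dimensions that were left unset, which makes them indistinguishable from a real size of `0`. A minimal sketch of a more explicit reader follows, assuming the dimension message exposes `dim_value` and `dim_param` in a protobuf oneof named `value` (as in the ONNX `TensorShapeProto` definition). `describe_shape` is a hypothetical helper name, not part of onnx.\n",
    "\n",
    "```python\n",
    "from onnx import TensorProto\n",
    "from onnx.helper import make_tensor_value_info\n",
    "\n",
    "\n",
    "def describe_shape(shape):\n",
    "    # Hypothetical helper: int for a fixed size, str for a named\n",
    "    # (symbolic) dimension, None for a dimension left unset.\n",
    "    dims = []\n",
    "    for d in shape.dim:\n",
    "        kind = d.WhichOneof('value')\n",
    "        if kind == 'dim_value':\n",
    "            dims.append(d.dim_value)\n",
    "        elif kind == 'dim_param':\n",
    "            dims.append(d.dim_param)\n",
    "        else:\n",
    "            dims.append(None)\n",
    "    return tuple(dims)\n",
    "\n",
    "\n",
    "# 'batch' is an arbitrary illustrative name for a symbolic dimension.\n",
    "info = make_tensor_value_info('Z', TensorProto.FLOAT, ['batch', 3, None])\n",
    "described = describe_shape(info.type.tensor_type.shape)\n",
    "print(described)\n",
    "```\n",
    "\n",
    "Applied to the inputs of the model above, such a helper would report `(None, None)` rather than `(0, 0)`."
   ]
  },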
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inspect the nodes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "name='' type='MatMul' input=['X', 'A'] output=['XA']\n",
      "name='' type='Add' input=['XA', 'B'] output=['Y']\n"
     ]
    }
   ],
   "source": [
    "for node in onnx_model.graph.node:\n",
    "    print(\"name=%r type=%r input=%r output=%r\" % (\n",
    "        node.name, node.op_type, node.input, node.output))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ONNX serialization and deserialization\n",
    "\n",
    "Every object in onnx (see [Protos](https://onnx.ai/onnx/api/classes.html#l-onnx-classes)) can be serialized with its `SerializeToString` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p .temp\n",
    "with open(\".temp/linear_regression.onnx\", \"wb\") as f:\n",
    "    f.write(onnx_model.SerializeToString())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deserialize, that is, load the serialized model back:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "\n",
    "with open(\".temp/linear_regression.onnx\", \"rb\") as f:\n",
    "    onnx_model = onnx.load(f)"
   ]
  },
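  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Writing the bytes by hand is not required: `onnx.save` and `onnx.load` also accept a file path. The sketch below is a self-contained round trip through a temporary directory. The one-node graph and the file name `model.onnx` are illustrative choices, not taken from the example above.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import onnx\n",
    "from onnx import TensorProto\n",
    "from onnx.helper import (\n",
    "    make_graph, make_model, make_node, make_tensor_value_info)\n",
    "\n",
    "# A tiny illustrative model: Y = Abs(X).\n",
    "X = make_tensor_value_info('X', TensorProto.FLOAT, [None])\n",
    "Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None])\n",
    "node = make_node('Abs', ['X'], ['Y'])\n",
    "model = make_model(make_graph([node], 'abs', [X], [Y]))\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    path = os.path.join(tmp, 'model.onnx')\n",
    "    onnx.save(model, path)      # serialize the ModelProto to disk\n",
    "    reloaded = onnx.load(path)  # parse it back into a ModelProto\n",
    "\n",
    "# Protobuf messages compare field by field.\n",
    "round_trip_ok = reloaded == model\n",
    "print(round_trip_ok)\n",
    "```"
   ]
  },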
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Data (tensors) can be serialized as well:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n",
      "<class 'onnx.onnx_ml_pb2.TensorProto'>\n",
      "<class 'bytes'>\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from onnx.numpy_helper import from_array\n",
    "\n",
    "numpy_tensor = np.array([0, 1, 4, 5, 3], dtype=np.float32)\n",
    "print(type(numpy_tensor))\n",
    "\n",
    "onnx_tensor = from_array(numpy_tensor)\n",
    "print(type(onnx_tensor))\n",
    "\n",
    "serialized_tensor = onnx_tensor.SerializeToString()\n",
    "print(type(serialized_tensor))\n",
    "\n",
    "with open(\".temp/saved_tensor.pb\", \"wb\") as f:\n",
    "    f.write(serialized_tensor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deserialize the data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'bytes'>\n",
      "<class 'onnx.onnx_ml_pb2.TensorProto'>\n",
      "[0. 1. 4. 5. 3.]\n"
     ]
    }
   ],
   "source": [
    "from onnx import TensorProto\n",
    "from onnx.numpy_helper import to_array\n",
    "\n",
    "with open(\".temp/saved_tensor.pb\", \"rb\") as f:\n",
    "    serialized_tensor = f.read()\n",
    "print(type(serialized_tensor))\n",
    "\n",
    "onnx_tensor = TensorProto()\n",
    "onnx_tensor.ParseFromString(serialized_tensor)\n",
    "print(type(onnx_tensor))\n",
    "\n",
    "numpy_tensor = to_array(onnx_tensor)\n",
    "print(numpy_tensor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The convenience function `load_tensor_from_string` does the same in one call:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'onnx.onnx_ml_pb2.TensorProto'>\n"
     ]
    }
   ],
   "source": [
    "from onnx import load_tensor_from_string\n",
    "\n",
    "with open(\".temp/saved_tensor.pb\", \"rb\") as f:\n",
    "    serialized = f.read()\n",
    "proto = load_tensor_from_string(serialized)\n",
    "print(type(proto))"
   ]
  },
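  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ONNX initializers and default values\n",
    "\n",
    "The previous model assumed that the coefficients of the linear regression were also inputs of the model, which is not very convenient. To follow ONNX semantics, they should be part of the model itself, as constants or initializers. The next example modifies the previous one so that the former inputs `A` and `B` become initializers (named `A` and `C` in the code below). See [array](https://onnx.ai/onnx/api/numpy_helper.html#l-numpy-helper-onnx-array).\n",
    "\n",
    "- `onnx.numpy_helper.to_array`: converts an ONNX tensor (`TensorProto`) into a NumPy array.\n",
    "- `onnx.numpy_helper.from_array`: converts a NumPy array into an ONNX tensor (`TensorProto`)."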
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from onnx import numpy_helper, TensorProto\n",
    "from onnx.helper import (\n",
    "    make_model, make_node, make_graph,\n",
    "    make_tensor_value_info)\n",
    "from onnx.checker import check_model\n",
    "\n",
    "# initializers\n",
    "value = np.array([0.5, -0.6], dtype=np.float32)\n",
    "A = numpy_helper.from_array(value, name='A')\n",
    "\n",
    "value = np.array([0.4], dtype=np.float32)\n",
    "C = numpy_helper.from_array(value, name='C')\n",
    "\n",
    "X = make_tensor_value_info('X', TensorProto.FLOAT, [None, None])\n",
    "Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None])\n",
    "node1 = make_node('MatMul', ['X', 'A'], ['AX'])\n",
    "node2 = make_node('Add', ['AX', 'C'], ['Y'])\n",
    "graph = make_graph([node1, node2], 'lr', [X], [Y], [A, C])\n",
    "onnx_model = make_model(graph)\n",
    "check_model(onnx_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inspect the initializers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dims: 2\n",
      "data_type: 1\n",
      "name: \"A\"\n",
      "raw_data: \"\\000\\000\\000?\\232\\231\\031\\277\"\n",
      "\n",
      "dims: 1\n",
      "data_type: 1\n",
      "name: \"C\"\n",
      "raw_data: \"\\315\\314\\314>\"\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for init in onnx_model.graph.initializer:\n",
    "    print(init)"
   ]
  },
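  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `raw_data` fields printed above hold the raw bytes of the float32 values, which is hard to read. As a sketch, `numpy_helper.to_array` decodes such a tensor back into a NumPy array. The tensor is rebuilt here so that the snippet is self-contained.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from onnx import numpy_helper\n",
    "\n",
    "# Same values as the initializer A above.\n",
    "value = np.array([0.5, -0.6], dtype=np.float32)\n",
    "A = numpy_helper.from_array(value, name='A')\n",
    "\n",
    "# Decode the TensorProto back into a NumPy array.\n",
    "decoded = numpy_helper.to_array(A)\n",
    "print(A.name, decoded.dtype, decoded.shape, decoded)\n",
    "```\n",
    "\n",
    "The same call can be applied to each element of `onnx_model.graph.initializer`."
   ]
  },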
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ONNX node attributes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ir_version: 10\n",
      "graph {\n",
      "  node {\n",
      "    input: \"A\"\n",
      "    output: \"tA\"\n",
      "    op_type: \"Transpose\"\n",
      "    attribute {\n",
      "      name: \"perm\"\n",
      "      ints: 1\n",
      "      ints: 0\n",
      "      type: INTS\n",
      "    }\n",
      "  }\n",
      "  node {\n",
      "    input: \"X\"\n",
      "    input: \"tA\"\n",
      "    output: \"XA\"\n",
      "    op_type: \"MatMul\"\n",
      "  }\n",
      "  node {\n",
      "    input: \"XA\"\n",
      "    input: \"B\"\n",
      "    output: \"Y\"\n",
      "    op_type: \"Add\"\n",
      "  }\n",
      "  name: \"lr\"\n",
      "  input {\n",
      "    name: \"X\"\n",
      "    type {\n",
      "      tensor_type {\n",
      "        elem_type: 1\n",
      "        shape {\n",
      "          dim {\n",
      "          }\n",
      "          dim {\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "  input {\n",
      "    name: \"A\"\n",
      "    type {\n",
      "      tensor_type {\n",
      "        elem_type: 1\n",
      "        shape {\n",
      "          dim {\n",
      "          }\n",
      "          dim {\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "  input {\n",
      "    name: \"B\"\n",
      "    type {\n",
      "      tensor_type {\n",
      "        elem_type: 1\n",
      "        shape {\n",
      "          dim {\n",
      "          }\n",
      "          dim {\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "  output {\n",
      "    name: \"Y\"\n",
      "    type {\n",
      "      tensor_type {\n",
      "        elem_type: 1\n",
      "        shape {\n",
      "          dim {\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "opset_import {\n",
      "  version: 21\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from onnx import TensorProto\n",
    "from onnx.helper import (\n",
    "    make_model, make_node, make_graph,\n",
    "    make_tensor_value_info)\n",
    "from onnx.checker import check_model\n",
    "\n",
    "# unchanged\n",
    "X = make_tensor_value_info('X', TensorProto.FLOAT, [None, None])\n",
    "A = make_tensor_value_info('A', TensorProto.FLOAT, [None, None])\n",
    "B = make_tensor_value_info('B', TensorProto.FLOAT, [None, None])\n",
    "Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None])\n",
    "\n",
    "# 添加属性\n",
    "node_transpose = make_node('Transpose', ['A'], ['tA'], perm=[1, 0])\n",
    "\n",
    "# unchanged except A is replaced by tA\n",
    "node1 = make_node('MatMul', ['X', 'tA'], ['XA'])\n",
    "node2 = make_node('Add', ['XA', 'B'], ['Y'])\n",
    "\n",
    "# node_transpose is added to the list\n",
    "graph = make_graph([node_transpose, node1, node2],\n",
    "                   'lr', [X, A, B], [Y])\n",
    "onnx_model = make_model(graph)\n",
    "check_model(onnx_model)\n",
    "\n",
    "# the work is done, let's display it...\n",
    "print(onnx_model)"
   ]
  },
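  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Attributes are stored on the node as a list of `AttributeProto` messages. Below is a short sketch of reading them back as plain Python values with `onnx.helper.get_attribute_value`. Collecting them into a dictionary is just one convenient presentation, not something the library requires.\n",
    "\n",
    "```python\n",
    "from onnx.helper import get_attribute_value, make_node\n",
    "\n",
    "# Same Transpose node as in the example above.\n",
    "node = make_node('Transpose', ['A'], ['tA'], perm=[1, 0])\n",
    "\n",
    "# Map each attribute name to its decoded Python value.\n",
    "attrs = {a.name: get_attribute_value(a) for a in node.attribute}\n",
    "print(attrs)\n",
    "```"
   ]
  },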
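  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ONNX evaluation and runtime\n",
    "\n",
    "The full API is described in [onnx.reference](https://onnx.ai/onnx/api/reference.html#l-reference-implementation). The `ReferenceEvaluator` class takes a model (a `ModelProto`, a filename, ...). Its `run` method returns the outputs for a given set of inputs specified in a dictionary."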
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[array([[-1.8057449],\n",
      "       [-2.0268912],\n",
      "       [-1.369731 ],\n",
      "       [-1.8334708]], dtype=float32)]\n"
     ]
    }
   ],
   "source": [
    "import numpy\n",
    "from onnx import numpy_helper, TensorProto\n",
    "from onnx.helper import (\n",
    "    make_model, make_node, set_model_props, make_tensor,\n",
    "    make_graph, make_tensor_value_info)\n",
    "from onnx.checker import check_model\n",
    "from onnx.reference import ReferenceEvaluator\n",
    "\n",
    "X = make_tensor_value_info('X', TensorProto.FLOAT, [None, None])\n",
    "A = make_tensor_value_info('A', TensorProto.FLOAT, [None, None])\n",
    "B = make_tensor_value_info('B', TensorProto.FLOAT, [None, None])\n",
    "Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None])\n",
    "node1 = make_node('MatMul', ['X', 'A'], ['XA'])\n",
    "node2 = make_node('Add', ['XA', 'B'], ['Y'])\n",
    "graph = make_graph([node1, node2], 'lr', [X, A, B], [Y])\n",
    "onnx_model = make_model(graph)\n",
    "check_model(onnx_model)\n",
    "\n",
    "sess = ReferenceEvaluator(onnx_model)\n",
    "\n",
    "x = numpy.random.randn(4, 2).astype(numpy.float32)\n",
    "a = numpy.random.randn(2, 1).astype(numpy.float32)\n",
    "b = numpy.random.randn(1, 1).astype(numpy.float32)\n",
    "feeds = {'X': x, 'A': a, 'B': b}\n",
    "\n",
    "print(sess.run(None, feeds))"
   ]
  },
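  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the inputs are random, the printed numbers alone say little. One way to gain confidence is to compare the evaluator's output with the same formula written directly in NumPy. The sketch below does that. The fixed seed and the tolerances are arbitrary choices, and `Y` is declared with two dimensions here to match the `(N, 1)` result.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from onnx import TensorProto\n",
    "from onnx.helper import (\n",
    "    make_graph, make_model, make_node, make_tensor_value_info)\n",
    "from onnx.reference import ReferenceEvaluator\n",
    "\n",
    "X = make_tensor_value_info('X', TensorProto.FLOAT, [None, None])\n",
    "A = make_tensor_value_info('A', TensorProto.FLOAT, [None, None])\n",
    "B = make_tensor_value_info('B', TensorProto.FLOAT, [None, None])\n",
    "Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None, None])\n",
    "graph = make_graph(\n",
    "    [make_node('MatMul', ['X', 'A'], ['XA']),\n",
    "     make_node('Add', ['XA', 'B'], ['Y'])],\n",
    "    'lr', [X, A, B], [Y])\n",
    "sess = ReferenceEvaluator(make_model(graph))\n",
    "\n",
    "rng = np.random.default_rng(0)  # fixed seed, chosen arbitrarily\n",
    "x = rng.standard_normal((4, 2)).astype(np.float32)\n",
    "a = rng.standard_normal((2, 1)).astype(np.float32)\n",
    "b = rng.standard_normal((1, 1)).astype(np.float32)\n",
    "\n",
    "# run returns a list with one array per requested output.\n",
    "(y,) = sess.run(None, {'X': x, 'A': a, 'B': b})\n",
    "\n",
    "# The same linear regression written directly in NumPy.\n",
    "expected = x @ a + b\n",
    "matches = np.allclose(y, expected, rtol=1e-5, atol=1e-6)\n",
    "print(y.shape, matches)\n",
    "```"
   ]
  },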
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ONNX node evaluation\n",
    "\n",
    "The evaluator can also evaluate a single node, to check how an operator behaves on a specific input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[array([[1., 0.],\n",
      "       [0., 1.],\n",
      "       [0., 0.],\n",
      "       [0., 0.]], dtype=float32)]\n"
     ]
    }
   ],
   "source": [
    "import numpy\n",
    "from onnx import numpy_helper, TensorProto\n",
    "from onnx.helper import make_node\n",
    "\n",
    "from onnx.reference import ReferenceEvaluator\n",
    "\n",
    "node = make_node('EyeLike', ['X'], ['Y'])\n",
    "\n",
    "sess = ReferenceEvaluator(node)\n",
    "\n",
    "x = numpy.random.randn(4, 2).astype(numpy.float32)\n",
    "feeds = {'X': x}\n",
    "\n",
    "print(sess.run(None, feeds))"
   ]
  },
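  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similar code also works on a `GraphProto` or a `FunctionProto`.\n",
    "\n",
    "### ONNX step-by-step evaluation\n",
    "\n",
    "Conversion libraries take an existing model trained with a machine learning framework (such as PyTorch or scikit-learn) and convert it into an ONNX graph. Complex models usually do not work on the first try, and looking at intermediate results can help locate the part that was converted incorrectly. The `verbose` parameter controls how much information about intermediate results is displayed."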
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "------ verbose=1\n",
      "\n",
      "[array([[ 0.9199427 ],\n",
      "       [ 0.3668961 ],\n",
      "       [ 0.04145484],\n",
      "       [-0.4361933 ]], dtype=float32)]\n",
      "\n",
      "------ verbose=2\n",
      "\n",
      "MatMul(X, A) -> XA\n",
      "Add(XA, B) -> Y\n",
      "[array([[3.8724718],\n",
      "       [4.0825634],\n",
      "       [0.4058497],\n",
      "       [2.005159 ]], dtype=float32)]\n",
      "\n",
      "------ verbose=3\n",
      "\n",
      " +I X: float32:(4, 2) in [-1.1875110864639282, 1.6281424760818481]\n",
      " +I A: float32:(2, 1) in [-1.774684190750122, -1.7485564947128296]\n",
      " +I B: float32:(1, 1) in [0.5859348177909851, 0.5859348177909851]\n",
      "MatMul(X, A) -> XA\n",
      " + XA: float32:(4, 1) in [-2.718858480453491, 3.8624463081359863]\n",
      "Add(XA, B) -> Y\n",
      " + Y: float32:(4, 1) in [-2.1329236030578613, 4.448380947113037]\n",
      "[array([[ 1.2496312 ],\n",
      "       [ 4.448381  ],\n",
      "       [-2.1329236 ],\n",
      "       [ 0.39690298]], dtype=float32)]\n",
      "\n",
      "------ verbose=4\n",
      "\n",
      " +I X: float32:(4, 2):1.1294219493865967,0.30637702345848083,-0.029152901843190193,0.06715787202119827,0.20514409244060516...\n",
      " +I A: float32:(2, 1):[-0.08160115778446198, 1.6542348861694336]\n",
      " +I B: float32:(1, 1):[-0.9010018110275269]\n",
      "MatMul(X, A) -> XA\n",
      " + XA: float32:(4, 1):[0.4146574139595032, 0.1134738028049469, 2.200101137161255, -0.3942927122116089]\n",
      "Add(XA, B) -> Y\n",
      " + Y: float32:(4, 1):[-0.4863443970680237, -0.7875280380249023, 1.299099326133728, -1.2952945232391357]\n",
      "[array([[-0.4863444 ],\n",
      "       [-0.78752804],\n",
      "       [ 1.2990993 ],\n",
      "       [-1.2952945 ]], dtype=float32)]\n"
     ]
    }
   ],
   "source": [
    "import numpy\n",
    "from onnx import TensorProto\n",
    "from onnx.helper import (\n",
    "    make_model, make_node, make_graph, make_tensor_value_info)\n",
    "from onnx.checker import check_model\n",
    "from onnx.reference import ReferenceEvaluator\n",
    "\n",
    "# Graph inputs and output; None marks a dimension of unknown size.\n",
    "X = make_tensor_value_info('X', TensorProto.FLOAT, [None, None])\n",
    "A = make_tensor_value_info('A', TensorProto.FLOAT, [None, None])\n",
    "B = make_tensor_value_info('B', TensorProto.FLOAT, [None, None])\n",
    "# Y = XA + B is a 2-D tensor, so declare two dimensions.\n",
    "Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None, None])\n",
    "node1 = make_node('MatMul', ['X', 'A'], ['XA'])\n",
    "node2 = make_node('Add', ['XA', 'B'], ['Y'])\n",
    "graph = make_graph([node1, node2], 'lr', [X, A, B], [Y])\n",
    "onnx_model = make_model(graph)\n",
    "check_model(onnx_model)\n",
    "\n",
    "# Higher verbose levels print progressively more detail about each node.\n",
    "for verbose in [1, 2, 3, 4]:\n",
    "    print()\n",
    "    print(f\"------ verbose={verbose}\")\n",
    "    print()\n",
    "    sess = ReferenceEvaluator(onnx_model, verbose=verbose)\n",
    "\n",
    "    x = numpy.random.randn(4, 2).astype(numpy.float32)\n",
    "    a = numpy.random.randn(2, 1).astype(numpy.float32)\n",
    "    b = numpy.random.randn(1, 1).astype(numpy.float32)\n",
    "    feeds = {'X': x, 'A': a, 'B': b}\n",
    "\n",
    "    print(sess.run(None, feeds))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluating a custom node with ONNX\n",
    "\n",
    "The following example still implements a linear regression, but adds the identity matrix to `A`:\n",
    "\n",
    "$$\n",
    "Y = X(A + I) + B\n",
    "$$"
   ]
  },
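  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Expanding the product makes the effect of the identity term explicit. This is a worked sketch, assuming `X` has shape $(m, n)$, `A` is square with shape $(n, n)$, and `B` broadcasts against an $(m, n)$ tensor; the symbols $m$ and $n$ are introduced here only for illustration:\n",
    "\n",
    "$$\n",
    "Y = X(A + I) + B = XA + XI + B = XA + X + B\n",
    "$$\n",
    "\n",
    "Under these assumptions, the result is the plain linear regression $XA + B$ with the input $X$ added back in. In an ONNX graph, the `EyeLike` operator can supply $I$: it returns a tensor with the same shape as its input and ones on the main diagonal."
   ]
  },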
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EyeLike(A) -> Eye\n",
      "Add(A, Eye) -> A1\n",
      "MatMul(X, A1) -> XA1\n",
      "Add(XA1, B) -> Y\n",
      "[array([[ 1.5872786 ,  2.8542266 ],\n",
      "       [ 0.64192116, -1.225003  ],\n",
      "       [-1.0039632 ,  1.1880492 ],\n",
      "       [-0.31344515,  0.10079634]], dtype=float32)]\n"
     ]
    }
   ],
   "source": [
    "import numpy\n",
    "from onnx import TensorProto\n",
    "from onnx.helper import (\n",
    "    make_model, make_node, make_graph, make_tensor_value_info)\n",
    "from onnx.checker import check_model\n",
    "from onnx.reference import ReferenceEvaluator\n",
    "\n",
    "X = make_tensor_value_info('X', TensorProto.FLOAT, [None, None])\n",
    "A = make_tensor_value_info('A', TensorProto.FLOAT, [None, None])\n",
    "B = make_tensor_value_info('B', TensorProto.FLOAT, [None, None])\n",
    "# Y = X(A + I) + B is a 2-D tensor, so declare two dimensions.\n",
    "Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None, None])\n",
    "# EyeLike builds a tensor shaped like A with ones on the main diagonal.\n",
    "node0 = make_node('EyeLike', ['A'], ['Eye'])\n",
    "node1 = make_node('Add', ['A', 'Eye'], ['A1'])\n",
    "node2 = make_node('MatMul', ['X', 'A1'], ['XA1'])\n",
    "node3 = make_node('Add', ['XA1', 'B'], ['Y'])\n",
    "graph = make_graph([node0, node1, node2, node3], 'lr', [X, A, B], [Y])\n",
    "onnx_model = make_model(graph)\n",
    "check_model(onnx_model)\n",
    "# Save the serialized model to disk.\n",
    "with open(\"linear_regression.onnx\", \"wb\") as f:\n",
    "    f.write(onnx_model.SerializeToString())\n",
    "\n",
    "sess = ReferenceEvaluator(onnx_model, verbose=2)\n",
    "\n",
    "x = numpy.random.randn(4, 2).astype(numpy.float32)\n",
    "a = numpy.random.randn(2, 2).astype(numpy.float32) / 10\n",
    "b = numpy.random.randn(1, 2).astype(numpy.float32)\n",
    "feeds = {'X': x, 'A': a, 'B': b}\n",
    "\n",
    "print(sess.run(None, feeds))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xxx",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
